/* Copyright 2023 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/gpu/triton_kernels.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/triton_utils.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"

#ifdef JAX_GPU_CUDA
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#endif  // JAX_GPU_CUDA

#ifdef JAX_GPU_HIP
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#endif  // JAX_GPU_HIP

#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))

namespace jax::JAX_GPU_NAMESPACE {
namespace {

constexpr float kBenchmarkTimeMillis = 10.;

struct gpuModuleDeleter {
  void operator()(gpuModule_t module) {
    absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module));
    if (!status.ok()) {
      LOG(WARNING) << "Failed to unload GPU module: " << status;
    }
  }
};

using OwnedGPUmodule =
    std::unique_ptr<std::remove_pointer_t<gpuModule_t>, gpuModuleDeleter>;

absl::StatusOr<gpuDevice_t> GetStreamDevice(gpuStream_t stream) {
  gpuDevice_t device;
#ifdef JAX_GPU_HIP
  int device_id = gpuGetStreamDeviceId(stream);
  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
#else  // JAX_GPU_CUDA
  gpuContext_t context;
  GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
  absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
  GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));
#endif
  return device;
}

absl::StatusOr<uint32_t> MaxSharedMemoryPerBlock(gpuDevice_t device) {
  int shared_optin;
  GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
      &shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  return shared_optin;
}

absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
                                            uint32_t shared_mem_bytes,
                                            std::string_view ptx,
                                            int compute_capability) {
  auto key =
      std::make_tuple(kernel_name, shared_mem_bytes, ptx, compute_capability);

  static absl::Mutex mutex;
  static auto& module_images =
      *new absl::flat_hash_map<decltype(key), std::unique_ptr<ModuleImage>>
          ABSL_GUARDED_BY(mutex);

  absl::MutexLock lock(mutex);
  auto it = module_images.find(key);
  if (it != module_images.end()) return it->second.get();

#ifdef JAX_GPU_HIP  // For HIP/ROCM just read the hsaco file
  std::string result_blob;
  std::string fname{ptx};
  TF_RETURN_IF_ERROR(
      tsl::ReadFileToString(tsl::Env::Default(), fname, &result_blob));
  std::vector<uint8_t> module_image(result_blob.begin(), result_blob.end());
#else
  // TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
  int cc_major = compute_capability / 10;
  int cc_minor = compute_capability % 10;

  bool has_accelerated_features = cc_major >= 9;
  using FeatureExtension =
      stream_executor::CudaComputeCapability::FeatureExtension;
  const stream_executor::CudaComputeCapability cc(
      cc_major, cc_minor,
      has_accelerated_features ? FeatureExtension::kAcceleratedFeatures
                               : FeatureExtension::kNone);
  JAX_ASSIGN_OR_RETURN(
      std::vector<uint8_t> module_image,
      stream_executor::CompileGpuAsm(cc, std::string(ptx),
                                     stream_executor::GpuAsmOpts{}));
#endif

  auto [it2, success] = module_images.insert(
      {std::move(key),
       std::make_unique<ModuleImage>(
           std::move(kernel_name), std::move(module_image), shared_mem_bytes)});
  CHECK(success);
  return it2->second.get();
}

absl::StatusOr<float> Benchmark(gpuStream_t stream, KernelCall& kernel_call,
                                void** buffers, int num_iterations) {
  gpuEvent_t start, stop;
  GPU_RETURN_IF_ERROR(gpuEventCreate(&start, /*Flags=*/GPU_EVENT_DEFAULT));
  GPU_RETURN_IF_ERROR(gpuEventCreate(&stop, /*Flags=*/GPU_EVENT_DEFAULT));
  JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers));  // Warm-up.
  GPU_RETURN_IF_ERROR(gpuEventRecord(start, stream));
  for (int i = 0; i < num_iterations; ++i) {
    JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers));
  }
  GPU_RETURN_IF_ERROR(gpuEventRecord(stop, stream));
  GPU_RETURN_IF_ERROR(gpuEventSynchronize(stop));
  float elapsed_ms;
  GPU_RETURN_IF_ERROR(gpuEventElapsedTime(&elapsed_ms, start, stop));
  GPU_RETURN_IF_ERROR(gpuEventDestroy(start));
  GPU_RETURN_IF_ERROR(gpuEventDestroy(stop));
  return elapsed_ms;
}

absl::StatusOr<KernelCall*> GetKernelCall(std::string_view opaque,
                                          gpuStream_t stream, void** buffers) {
  static absl::Mutex mutex;
  static auto& kernel_calls =
      *new absl::flat_hash_map<std::string,
                               absl::StatusOr<std::unique_ptr<KernelCall>>>
          ABSL_GUARDED_BY(mutex);

  {
    // Fast path uses reader lock (as hash map look-up is relatively slow).
    absl::ReaderMutexLock lock(mutex);
    auto it = kernel_calls.find(opaque);
    if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) {
      JAX_RETURN_IF_ERROR(it->second.status());
      return it->second->get();
    }
  }

  if (opaque.empty()) {
    return absl::InvalidArgumentError("Opaque data is empty.");
  }

  absl::MutexLock lock(mutex);

  auto get_kernel_call = [&]() -> absl::StatusOr<std::unique_ptr<KernelCall>> {
    // The opaque data is a zlib compressed protobuf.
    JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));

    jax_triton::TritonAnyKernelCall proto;
    if (!proto.ParseFromString(serialized)) {
      return absl::InvalidArgumentError("Failed to parse serialized data.");
    }

    if (proto.has_kernel_call()) {
      JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_,
                           KernelCall::FromProto(proto.kernel_call()));
      return std::make_unique<KernelCall>(std::move(kernel_call_));
    } else if (proto.has_autotuned_kernel_call()) {
      JAX_ASSIGN_OR_RETURN(
          AutotunedKernelCall autotuned_call,
          AutotunedKernelCall::FromProto(proto.autotuned_kernel_call()));
      {
        JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_,
                             AutotunedKernelCall::Autotune(
                                 std::move(autotuned_call), stream, buffers));
        return std::make_unique<KernelCall>(std::move(kernel_call_));
      }
    } else {
      return absl::InvalidArgumentError("Unknown kernel call type.");
    }
  };

  // We released the reader lock, so it may have been written by another thread.
  // Create a new entry if it already exists or create a new one.
  auto it = kernel_calls.emplace(std::string(opaque), get_kernel_call()).first;

  JAX_RETURN_IF_ERROR(it->second.status());
  return it->second->get();
}

}  // namespace

class ModuleImage {
 public:
  ModuleImage(std::string_view kernel_name, std::vector<uint8_t> module_image,
              uint32_t shared_mem_bytes)
      : kernel_name_(kernel_name),
        module_image_(std::move(module_image)),
        shared_mem_bytes_(shared_mem_bytes) {}

  absl::StatusOr<gpuFunction_t> GetFunctionForContext(gpuContext_t context) {
    absl::MutexLock lock(mutex_);
    auto it = functions_.find(context);
    if (ABSL_PREDICT_TRUE(it != functions_.end())) {
      return it->second;
    }

    GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
    absl::Cleanup ctx_restorer = [] {
      absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr));
      if (!status.ok()) {
        LOG(WARNING) << "Failed to pop GPU context: " << status;
      }
    };

    gpuModule_t module;
    GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data()));
    modules_.push_back(OwnedGPUmodule(module, gpuModuleDeleter()));

    gpuFunction_t function;
    GPU_RETURN_IF_ERROR(
        gpuModuleGetFunction(&function, module, kernel_name_.c_str()));
    auto [_, success] = functions_.insert({context, function});
    CHECK(success);

    // The maximum permitted static shared memory allocation in CUDA is 48kB,
    // but we can expose more to the kernel using dynamic shared memory.
    constexpr int kMaxStaticSharedMemBytes = 49152;
    if (shared_mem_bytes_ <= kMaxStaticSharedMemBytes) {
      return function;
    }

    // Set up dynamic shared memory.
    gpuDevice_t device;
    GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));

    int shared_optin;
    GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
        &shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));

    if (shared_mem_bytes_ > shared_optin) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "Shared memory requested (%d b) exceeds device resources (%d b).",
          shared_mem_bytes_, shared_optin));
    }

    if (shared_optin > kMaxStaticSharedMemBytes) {
#ifdef JAX_GPU_CUDA
      GPU_RETURN_IF_ERROR(
          gpuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED));
#endif
      int shared_total;
      GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
          &shared_total,
          GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
      int shared_static;
      GPU_RETURN_IF_ERROR(gpuFuncGetAttribute(
          &shared_static, GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
#ifdef JAX_GPU_CUDA
      GPU_RETURN_IF_ERROR(cuFuncSetAttribute(
          function, GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
          shared_optin - shared_static));
#endif
    }
    return function;
  }

 private:
  std::string kernel_name_;
  std::vector<uint8_t> module_image_;
  uint32_t shared_mem_bytes_;

  absl::Mutex mutex_;
  std::vector<OwnedGPUmodule> modules_ ABSL_GUARDED_BY(mutex_);
  absl::flat_hash_map<gpuContext_t, gpuFunction_t> functions_
      ABSL_GUARDED_BY(mutex_);
};

Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
               uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
               int compute_capability, uint32_t cluster_dim_0,
               uint32_t cluster_dim_1, uint32_t cluster_dim_2)
    : kernel_name_(std::move(kernel_name)),
      block_dim_x_(num_warps * kNumThreadsPerWarp),
      shared_mem_bytes_(shared_mem_bytes),
      ptx_(std::move(ptx)),
      ttir_(std::move(ttir)),
      compute_capability_(compute_capability),
      cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {}

absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
                            void** params) {
  if (ABSL_PREDICT_FALSE(module_image_ == nullptr)) {
    JAX_ASSIGN_OR_RETURN(module_image_,
                         GetModuleImage(kernel_name_, shared_mem_bytes_, ptx_,
                                        compute_capability_));
  }

  gpuContext_t context;
#ifdef JAX_GPU_HIP
  int device_id = gpuGetStreamDeviceId(stream);
  gpuDevice_t device;
  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
  GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
  JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
                       module_image_->GetFunctionForContext(context));
  return JAX_AS_STATUS(gpuLaunchKernel(
      kernel, grid[0], grid[1], grid[2], block_dim_x_,
      /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
      /*extra=*/nullptr));
#else  // JAX_GPU_CUDA
  // TODO(b/324319767): A bug in CUDA prevents us from calling cuStreamGetCtx
  // inside graph capture. We use cuCtxGetCurrent as a workaround here because
  // context is not updated, but we should change it back to cuStreamGetCtx once
  // the bug is fixed.
  gpustreamCaptureStatus_t capture_status;
  GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
  if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
    GPU_RETURN_IF_ERROR(gpuCtxGetCurrent(&context));
  } else {
    GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  }

  JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
                       module_image_->GetFunctionForContext(context));
  const uint32_t cluster_size =
      cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2];
  if (cluster_size <= 1) {
    return JAX_AS_STATUS(gpuLaunchKernel(
        kernel, grid[0], grid[1], grid[2], block_dim_x_,
        /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
        /*extra=*/nullptr));
  }
  CUlaunchAttribute launch_attrs[2];
  launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launch_attrs[0].value.clusterDim.x = cluster_dims_[0];
  launch_attrs[0].value.clusterDim.y = cluster_dims_[1];
  launch_attrs[0].value.clusterDim.z = cluster_dims_[2];
  launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  launch_attrs[1].value.clusterSchedulingPolicyPreference =
      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
  CUlaunchConfig launch_config = {
      /*gridDimX=*/grid[0] * cluster_dims_[0],
      /*gridDimY=*/grid[1] * cluster_dims_[1],
      /*gridDimZ=*/grid[2] * cluster_dims_[2],
      /*blockDimX=*/block_dim_x_,
      /*blockDimY=*/1,
      /*blockDimZ=*/1,
      /*sharedMemBytes=*/shared_mem_bytes_,
      /*hStream=*/stream,
      /**attrs=*/launch_attrs,
      /*numAttrs=*/2,
  };
  return JAX_AS_STATUS(
      cuLaunchKernelEx(&launch_config, kernel, params, /*extra=*/nullptr));
#endif
}

/*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) {
  return Kernel(proto.kernel_name(), proto.num_warps(),
                proto.shared_mem_bytes(), proto.ptx(), proto.ttir(),
                proto.compute_capability(), proto.cluster_dim_0(),
                proto.cluster_dim_1(), proto.cluster_dim_2());
}

jax_triton::TritonKernel Kernel::ToProto() const {
  jax_triton::TritonKernel proto;
  proto.set_kernel_name(kernel_name_);
  proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp);
  proto.set_shared_mem_bytes(shared_mem_bytes_);
  proto.set_ptx(ptx_);
  proto.set_ttir(ttir_);
  proto.set_compute_capability(compute_capability_);
  proto.set_cluster_dim_0(cluster_dims_[0]);
  proto.set_cluster_dim_1(cluster_dims_[1]);
  proto.set_cluster_dim_2(cluster_dims_[2]);
  return proto;
}

/*static*/ absl::StatusOr<KernelCall::Parameter>
KernelCall::Parameter::FromProto(
    const jax_triton::TritonKernelCall_Parameter& proto) {
  using jax_triton::TritonKernelCall_Parameter;
  Parameter param;
  switch (proto.value_case()) {
    case TritonKernelCall_Parameter::kArray:
      param.value = Array{proto.array().bytes_to_zero(),
                          proto.array().ptr_divisibility()};
      break;
    case TritonKernelCall_Parameter::kBool:
      param.value = proto.bool_();
      break;
    case TritonKernelCall_Parameter::kI32:
      param.value = proto.i32();
      break;
    case TritonKernelCall_Parameter::kU32:
      param.value = proto.u32();
      break;
    case TritonKernelCall_Parameter::kI64:
      param.value = proto.i64();
      break;
    case TritonKernelCall_Parameter::kU64:
      param.value = proto.u64();
      break;
    case TritonKernelCall_Parameter::kF32:
      param.value = proto.f32();
      break;
    case TritonKernelCall_Parameter::kF64:
      param.value = proto.f64();
      break;
    default:
      return absl::InvalidArgumentError("Unknown scalar parameter type.");
  }
  return param;
}

bool Kernel::CanLaunchOnDevice(gpuDevice_t device) const {
  return shared_mem_bytes_ <= MaxSharedMemoryPerBlock(device).value_or(0);
}

jax_triton::TritonKernelCall_Parameter KernelCall::Parameter::ToProto() const {
  jax_triton::TritonKernelCall_Parameter proto;
  if (std::holds_alternative<Array>(value)) {
    proto.mutable_array()->set_bytes_to_zero(
        std::get<Array>(value).bytes_to_zero);
    proto.mutable_array()->set_ptr_divisibility(
        std::get<Array>(value).ptr_divisibility);
  } else if (std::holds_alternative<bool>(value)) {
    proto.set_bool_(std::get<bool>(value));
  } else if (std::holds_alternative<int32_t>(value)) {
    proto.set_i32(std::get<int32_t>(value));
  } else if (std::holds_alternative<uint32_t>(value)) {
    proto.set_u32(std::get<uint32_t>(value));
  } else if (std::holds_alternative<int64_t>(value)) {
    proto.set_i64(std::get<int64_t>(value));
  } else if (std::holds_alternative<uint64_t>(value)) {
    proto.set_u64(std::get<uint64_t>(value));
  } else if (std::holds_alternative<float>(value)) {
    proto.set_f32(std::get<float>(value));
  } else {
    CHECK(std::holds_alternative<double>(value));
    proto.set_f64(std::get<double>(value));
  }
  return proto;
}

KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1,
                       uint32_t grid_2, std::vector<Parameter> parameters)
    : kernel_(std::move(kernel)),
      grid_{grid_0, grid_1, grid_2},
      parameters_(std::move(parameters)) {}

absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
  std::vector<void*> params;
  // We need an additional parameter for the scratchpad buffer.
  params.reserve(parameters_.size() + 1);
  for (size_t i = 0; i < parameters_.size(); ++i) {
    const Parameter& param = parameters_[i];
    if (std::holds_alternative<Parameter::Array>(param.value)) {
      const auto& array = std::get<Parameter::Array>(param.value);
      void*& ptr = *(buffers++);
      auto cu_ptr = reinterpret_cast<gpuDevicePtr_t>(ptr);

      if (ABSL_PREDICT_FALSE((array.ptr_divisibility != 0) &&
                             ((size_t)cu_ptr % array.ptr_divisibility != 0))) {
        return absl::InvalidArgumentError(
            absl::StrFormat("Parameter %zu (%zu) is not divisible by %d.", i,
                            (size_t)ptr, array.ptr_divisibility));
      }

      if (array.bytes_to_zero > 0) {
        GPU_RETURN_IF_ERROR(
            gpuMemsetD8Async(cu_ptr, 0, array.bytes_to_zero, stream));
      }
      params.push_back(&ptr);
    } else {
      params.push_back(const_cast<void*>(std::visit(
          [](auto&& arg) { return reinterpret_cast<const void*>(&arg); },
          param.value)));
    }
  }
  // Triton's kernel ABI expects an additional scratchpad global memory.
  // For now it is only used for on-device creation of TMA descriptors, which
  // we do not use yet, so we are just replacing this argument with a null
  // pointer.
  // TODO: b/381242007 - Allocate a proper buffer if we want to use
  // device-side TMA APIs.
  void* tma_descriptor_buffer = nullptr;  // Alive until kernel_.Launch returns.
  params.push_back(&tma_descriptor_buffer);
  void* profiling_buffer = nullptr;  // Alive until kernel_.Launch returns.
  params.push_back(&profiling_buffer);

  return kernel_.Launch(stream, grid_, params.data());
}

/*static*/ absl::StatusOr<KernelCall> KernelCall::FromProto(
    const jax_triton::TritonKernelCall& proto) {
  std::vector<KernelCall::Parameter> parameters;
  for (const jax_triton::TritonKernelCall_Parameter& parameter :
       proto.parameters()) {
    JAX_ASSIGN_OR_RETURN(Parameter p, Parameter::FromProto(parameter));
    parameters.push_back(p);
  }

  return KernelCall(Kernel::FromProto(proto.kernel()), proto.grid_0(),
                    proto.grid_1(), proto.grid_2(), std::move(parameters));
}

jax_triton::TritonKernelCall KernelCall::ToProto() const {
  jax_triton::TritonKernelCall proto;
  *proto.mutable_kernel() = kernel_.ToProto();
  proto.set_grid_0(grid_[0]);
  proto.set_grid_1(grid_[1]);
  proto.set_grid_2(grid_[2]);
  for (const Parameter& param : parameters_) {
    *proto.add_parameters() = param.ToProto();
  }
  return proto;
}

bool KernelCall::CanLaunchOnDevice(gpuDevice_t device) const {
  return kernel_.CanLaunchOnDevice(device);
}

AutotunedKernelCall::AutotunedKernelCall(
    std::string name, std::vector<Config> configs,
    std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases)
    : name_(std::move(name)),
      configs_(std::move(configs)),
      input_output_aliases_(std::move(input_output_aliases)) {}

/*static*/ absl::StatusOr<AutotunedKernelCall> AutotunedKernelCall::FromProto(
    const jax_triton::TritonAutotunedKernelCall& proto) {
  std::vector<Config> configs;
  for (const jax_triton::TritonAutotunedKernelCall_Config& config :
       proto.configs()) {
    JAX_ASSIGN_OR_RETURN(auto kernel_call,
                         KernelCall::FromProto(config.kernel_call()));
    configs.push_back(Config{std::move(kernel_call), config.description()});
  }

  std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases;
  for (const jax_triton::TritonAutotunedKernelCall_InputOutputAlias& a :
       proto.input_output_aliases()) {
    input_output_aliases.push_back(std::make_tuple(
        a.input_buffer_idx(), a.output_buffer_idx(), a.buffer_size_bytes()));
  }

  return AutotunedKernelCall(proto.name(), std::move(configs),
                             std::move(input_output_aliases));
}

jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
  jax_triton::TritonAutotunedKernelCall proto;
  proto.set_name(name_);
  for (const Config& config : configs_) {
    jax_triton::TritonAutotunedKernelCall_Config* c = proto.add_configs();
    *c->mutable_kernel_call() = config.kernel_call.ToProto();
    c->set_description(config.description);
  }
  for (const auto& [input_idx, output_idx, size] : input_output_aliases_) {
    jax_triton::TritonAutotunedKernelCall_InputOutputAlias* a =
        proto.add_input_output_aliases();
    a->set_input_buffer_idx(input_idx);
    a->set_output_buffer_idx(output_idx);
    a->set_buffer_size_bytes(size);
  }
  return proto;
}

/*static*/ absl::StatusOr<KernelCall> AutotunedKernelCall::Autotune(
    AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) {
  // Ensure a valid context for driver calls that don't take the stream.
  // gpuContext_t context;
  // GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
  // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };

  // Autotuning is not supported if the stream is in graph capture mode.
  gpustreamCaptureStatus_t capture_status;
  GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
  if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
    return absl::FailedPreconditionError(
        "Can't autotune Triton kernel when the stream is in graph capture "
        "mode. Autotuning can rely on real data present in input buffers to "
        "use them in address computation, but in graph capture mode buffers "
        "can have arbitrary data");
  }

  // If an input aliases with an output, it will get overwritten during the
  // kernel execution. If the kernel is called repeatedly, as we do during
  // auto-tuning, the final result will be junk, so we take a copy of the
  // input to restore after auto-tuning.
  std::unordered_map<size_t, std::vector<uint8_t>> input_copies;
  for (auto [input_idx, output_idx, size] : kernel_call.input_output_aliases_) {
    if (buffers[input_idx] == buffers[output_idx]) {
      std::vector<uint8_t> input_copy(size);
      GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
          input_copy.data(),
          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size, stream));
      input_copies[input_idx] = std::move(input_copy);
    }
  }

  LOG(INFO) << "Autotuning function: " << kernel_call.name_;
  // First run a single iteration of each to config to determine how many
  // iterations to run for benchmarking.
  float best = std::numeric_limits<float>::infinity();
  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
  absl::flat_hash_set<Config*> configs_to_skip;
  for (Config& config : kernel_call.configs_) {
    if (!config.kernel_call.CanLaunchOnDevice(device)) {
      configs_to_skip.insert(&config);
      continue;
    }
    JAX_ASSIGN_OR_RETURN(float t,
                         Benchmark(stream, config.kernel_call, buffers, 1));
    LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
    best = std::min(best, t);
  }

  int timed_iters = std::max(static_cast<int>(kBenchmarkTimeMillis / best), 1);
  if (timed_iters > 100) {
    timed_iters = 100;
    LOG(INFO) << "Benchmarking with 100 iters (capped at 100)";
  } else {
    timed_iters = std::min(timed_iters, 100);
    LOG(INFO) << "Benchmarking with " << timed_iters
              << " iters (target time: " << kBenchmarkTimeMillis << " ms)";
  }

  best = std::numeric_limits<float>::infinity();
  for (Config& config : kernel_call.configs_) {
    if (configs_to_skip.contains(&config)) {
      LOG(WARNING) << "Unable to launch autotune config on device: "
                   << config.description;
      continue;
    }

    JAX_ASSIGN_OR_RETURN(
        float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
    LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
              << t << " ms";

    if (t < best) {
      LOG(INFO) << config.description << " is the new best config";
      best = t;
      std::swap(config, kernel_call.configs_[0]);
    }
  }
  if (std::isinf(best)) {
    LOG(WARNING) << "Finished autotuning function: " << kernel_call.name_
                 << " no valid configs found.";
    return absl::FailedPreconditionError("No launchable configs.");
  }

  LOG(INFO) << "Finished autotuning function: " << kernel_call.name_
            << " best config " << kernel_call.configs_[0].description;

  // Restore aliased inputs to their original values.
  for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
    GPU_RETURN_IF_ERROR(
        gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
                           input_copies[input_idx].data(), size, stream));
  }

  // Synchronize stream to ensure copies are complete before the host copy
  // is deleted.
  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));

  return std::move(kernel_call.configs_[0].kernel_call);
}

void TritonKernelCall(gpuStream_t stream, void** buffers, const char* opaque,
                      size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = [=] {
    JAX_ASSIGN_OR_RETURN(
        KernelCall * kernel_call,
        GetKernelCall(std::string_view(opaque, opaque_len), stream, buffers));
    return kernel_call->Launch(stream, buffers);
  }();
  if (!result.ok()) {
    std::string_view msg = result.message();
    XlaCustomCallStatusSetFailure(status, msg.data(), msg.length());
  }
}

}  // namespace jax::JAX_GPU_NAMESPACE
